package com.yihecode.camera.ai.service;

import cn.hutool.core.util.StrUtil;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.yihecode.camera.ai.entity.Model;
import com.yihecode.camera.ai.entity.ModelDepend;
import com.yihecode.camera.ai.exception.BizException;
import com.yihecode.camera.ai.mapper.ModelMapper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import java.io.File;
import java.util.*;

/**
 * Model management service.
 *
 * <p>Models that share the same {@code name} are treated as versions of one logical model. In
 * this class a {@code state} of {@code 0} marks an enabled version and {@code 1} a disabled one;
 * {@link #updateModelEnable(Long)} switches which version of a name is enabled.
 */
@Service
public class ModelServiceImpl extends ServiceImpl<ModelMapper, Model> implements ModelService {

    /** Persists the dependency edges between models. */
    @Autowired
    private ModelDependService modelDependService;

    /**
     * Directory holding the ONNX files. File names are appended to it verbatim, so the configured
     * value presumably has to end with a path separator — confirm against the deployment config.
     */
    @Value("${modelDir}")
    public String modelDir;

    /**
     * Looks up a model by the MD5 checksum of its ONNX file.
     *
     * @param md5 ONNX file MD5 to match
     * @return the matching model, or {@code null} if none matches
     */
    @Override
    public Model getByOnnxMd5(String md5) {
        LambdaQueryWrapper<Model> queryWrapper = new LambdaQueryWrapper<>();
        queryWrapper.eq(Model::getOnnxMd5, md5);
        return this.getOne(queryWrapper);
    }

    /**
     * Looks up a model by its ONNX file name.
     *
     * @param fileName ONNX file name to match
     * @return the matching model, or {@code null} if none matches
     */
    @Override
    public Model getByOnnxName(String fileName) {
        LambdaQueryWrapper<Model> queryWrapper = new LambdaQueryWrapper<>();
        queryWrapper.eq(Model::getOnnxName, fileName);
        return this.getOne(queryWrapper);
    }

    /**
     * Pages through the enabled models ({@code state = 0}), newest first.
     *
     * @param pageObj page request (page number and size)
     * @return the requested page
     */
    @Override
    public IPage<Model> listPage(IPage<Model> pageObj) {
        LambdaQueryWrapper<Model> queryWrapper = new LambdaQueryWrapper<>();
        queryWrapper.eq(Model::getState, 0);
        queryWrapper.orderByDesc(Model::getCreatedAt);
        return this.page(pageObj, queryWrapper);
    }

    /**
     * Lists all enabled models ({@code state = 0}), oldest first.
     *
     * @return the enabled models; never {@code null}
     */
    @Override
    public List<Model> listData() {
        LambdaQueryWrapper<Model> queryWrapper = new LambdaQueryWrapper<>();
        queryWrapper.eq(Model::getState, 0);
        queryWrapper.orderByAsc(Model::getCreatedAt);
        List<Model> modelList = this.list(queryWrapper);
        return modelList == null ? new ArrayList<>() : modelList;
    }

    /**
     * Counts the enabled versions ({@code state = 0}) of the model with the given name.
     *
     * @param name model name
     * @return number of enabled versions
     */
    @Override
    public int getActiveCountByName(String name) {
        LambdaQueryWrapper<Model> queryWrapper = new LambdaQueryWrapper<>();
        queryWrapper.eq(Model::getName, name);
        queryWrapper.eq(Model::getState, 0);
        return this.count(queryWrapper);
    }

    /**
     * Counts all versions of the model with the given name, regardless of state.
     *
     * @param name model name
     * @return number of versions
     */
    @Override
    public int getVersionCountByName(String name) {
        LambdaQueryWrapper<Model> queryWrapper = new LambdaQueryWrapper<>();
        queryWrapper.eq(Model::getName, name);
        return this.count(queryWrapper);
    }

    /**
     * Writes {@code newVersionCount} into the version-count column of every version with the
     * given name.
     *
     * @param name            model name
     * @param newVersionCount value to store
     */
    @Override
    public void updateVersionCount(String name, int newVersionCount) {
        LambdaUpdateWrapper<Model> updateWrapper = new LambdaUpdateWrapper<>();
        updateWrapper.eq(Model::getName, name);
        updateWrapper.set(Model::getVersionCount, newVersionCount);
        this.update(null, updateWrapper);
    }

    /**
     * Creates a new model version (when {@code model.getId()} is {@code null}) or updates an
     * existing one. In both cases, its stored dependencies are then replaced with
     * {@code model.getModelIds()}.
     *
     * <p>A new version is saved disabled ({@code state = 1}) when another version with the same
     * name is already enabled; otherwise it is saved enabled ({@code state = 0}).
     *
     * @param model model to save; for an update, its name must match the stored name
     * @return a map with keys {@code msgType} ({@code 20001}, or {@code 20002} when the new
     *         version was saved disabled), {@code msgText} and {@code modelId}
     * @throws BizException if a dependency id does not exist, a dependency has the same name as
     *                      this model, the model to update does not exist, or its name was changed
     * @throws Exception    as declared by {@link ModelService}
     */
    @Override
    public Map<String, Object> saveModel(Model model) throws Exception {
        Map<String, Object> retMap = new HashMap<>();
        retMap.put("msgType", 20001);
        retMap.put("msgText", "操作成功");
        retMap.put("modelId", model.getId());

        checkDepends(model);

        if (model.getId() == null) {
            // Count enabled versions of this name before the new row is inserted.
            int activeCount = this.getActiveCountByName(model.getName());

            // Keep a single enabled version per name: the new version starts disabled if one is
            // already enabled (the caller is then asked whether to enable it, see 20002 below).
            model.setState(activeCount > 0 ? 1 : 0);
            model.setCreatedAt(new Date());
            // Placeholder; the real count is written to every version of this name right after the insert.
            model.setVersionCount(0);
            model.setOnnxSize(resolveOnnxSize(model.getOnnxName()));
            this.save(model);
            // The id is only known after the insert.
            retMap.put("modelId", model.getId());

            int newVersionCount = this.getVersionCountByName(model.getName());
            this.updateVersionCount(model.getName(), newVersionCount);

            if (activeCount > 0) {
                retMap.put("msgType", 20002);
                retMap.put("msgText", "操作成功，当前模型默认关闭状态，是否启用");
            }
        } else {
            Model modelDb = this.getById(model.getId());
            if (modelDb == null) {
                throw new BizException("模型不存在");
            }
            // The name groups versions together, so it must stay fixed.
            if (!Objects.equals(modelDb.getName(), model.getName())) {
                throw new BizException("模型名称不能修改");
            }
            this.saveOrUpdate(model);
        }

        replaceDepends(model);
        return retMap;
    }

    /**
     * Validates {@code model.getModelIds()}: every id must exist and must not refer to a model
     * with the same name (which would include the model itself).
     */
    private void checkDepends(Model model) throws BizException {
        if (model.getModelIds() == null) {
            return;
        }
        for (Long dependModelId : model.getModelIds()) {
            Model dependModel = this.getById(dependModelId);
            if (dependModel == null) {
                throw new BizException("依赖模型不存在");
            }
            if (Objects.equals(dependModel.getName(), model.getName())) {
                throw new BizException("不能依赖同名模型");
            }
        }
    }

    /**
     * Returns the size in bytes of {@code modelDir + onnxName}, or 0 when the name is blank or the
     * file does not exist.
     */
    private long resolveOnnxSize(String onnxName) {
        if (StrUtil.isBlank(onnxName)) {
            return 0L;
        }
        File onnxFile = new File(modelDir + onnxName);
        return onnxFile.exists() ? onnxFile.length() : 0L;
    }

    /** Deletes the stored dependencies of {@code model} and stores one row per id in {@code model.getModelIds()}. */
    private void replaceDepends(Model model) {
        Long modelId = model.getId();
        modelDependService.removeByModel(modelId);
        if (model.getModelIds() == null) {
            return;
        }
        for (Long dependModelId : model.getModelIds()) {
            ModelDepend modelDepend = new ModelDepend();
            modelDepend.setModelId(modelId);
            modelDepend.setDependModelId(dependModelId);
            modelDependService.save(modelDepend);
        }
    }

    /**
     * Enables the given model version and disables every other enabled version with the same
     * name.
     *
     * <p>NOTE(review): the two writes below are not wrapped in a transaction within this class;
     * confirm whether callers provide one (e.g. {@code @Transactional}).
     *
     * @param modelId id of the version to enable
     * @throws BizException if the model does not exist or is already enabled
     * @throws Exception    as declared by {@link ModelService}
     */
    @Override
    public void updateModelEnable(Long modelId) throws Exception {
        Model model = this.getById(modelId);
        if (model == null) {
            throw new BizException("找不到模型");
        }
        // Objects.equals avoids an unboxing NPE when the state is null.
        if (Objects.equals(model.getState(), 0)) {
            throw new BizException("模型已经为启用状态");
        }

        // Disable every currently enabled version sharing this name...
        LambdaUpdateWrapper<Model> updateWrapper = new LambdaUpdateWrapper<>();
        updateWrapper.set(Model::getState, 1);
        updateWrapper.eq(Model::getName, model.getName());
        updateWrapper.eq(Model::getState, 0);
        this.update(null, updateWrapper);

        // ...then enable the requested one.
        Model updateModel = new Model();
        updateModel.setId(modelId);
        updateModel.setState(0);
        this.updateById(updateModel);
    }

    /**
     * Lists every version sharing the name of the given model, newest first.
     *
     * @param modelId id of any version of the model
     * @return all versions with that name; an empty list if {@code modelId} is unknown
     */
    @Override
    public List<Model> listVersion(Long modelId) {
        Model model = this.getById(modelId);
        if (model == null) {
            return new ArrayList<>();
        }
        LambdaQueryWrapper<Model> queryWrapper = new LambdaQueryWrapper<>();
        queryWrapper.eq(Model::getName, model.getName());
        queryWrapper.orderByDesc(Model::getCreatedAt);
        List<Model> modelList = this.list(queryWrapper);
        return modelList == null ? new ArrayList<>() : modelList;
    }
}